{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MWPOsk-I8o69"
      },
      "outputs": [],
      "source": [
        "# @title Install code and dependencies {form-width: \"25%\"}\n",
        "# Use the %pip magic rather than !pip: it installs into the environment of the\n",
        "# running kernel, whereas !pip runs whichever pip the subshell finds first.\n",
        "%pip install git+https://github.com/google-deepmind/tapnet.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dNWBx_DOHSSt"
      },
      "outputs": [],
      "source": [
        "# @title Download Model {form-width: \"25%\"}\n",
        "\n",
        "# -p creates the missing parent directory and does not fail when the\n",
        "# directory already exists, so this cell can be re-run safely.\n",
        "%mkdir -p tapnet/checkpoints\n",
        "\n",
        "# -nc (no-clobber) skips the download when the checkpoint is already on disk\n",
        "# instead of saving a second copy as causal_tapir_checkpoint.npy.1.\n",
        "!wget -nc -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/causal_tapir_checkpoint.npy\n",
        "\n",
        "# List the directory so the downloaded file is visible in the cell output.\n",
        "%ls tapnet/checkpoints\n",
        "\n",
        "# Read by the tracking cell further down.\n",
        "checkpoint_path = 'tapnet/checkpoints/causal_tapir_checkpoint.npy'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jtTNXUNCHVAL"
      },
      "outputs": [],
      "source": [
        "# @title Imports {form-width: \"25%\"}\n",
        "%matplotlib widget\n",
        "import functools\n",
        "\n",
        "import haiku as hk\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import matplotlib.pyplot as plt\n",
        "import mediapy as media\n",
        "import numpy as np\n",
        "from tqdm import tqdm\n",
        "import tree\n",
        "\n",
        "from tapnet.robotap import tapir_clustering\n",
        "from tapnet.utils import transforms\n",
        "from tapnet.utils import viz_utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0J9kVfSuHmqS"
      },
      "outputs": [],
      "source": [
        "# @title Load an Exemplar Video {form-width: \"25%\"}\n",
        "\n",
        "# -p creates the missing parent directory and does not fail when the\n",
        "# directory already exists, so this cell can be re-run safely.\n",
        "%mkdir -p tapnet/examplar_videos\n",
        "\n",
        "# -nc (no-clobber) skips the download when the clip is already on disk\n",
        "# instead of saving a second copy as for_clustering.mp4.1.\n",
        "!wget -nc -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/robotap/for_clustering.mp4\n",
        "\n",
        "video = media.read_video('tapnet/examplar_videos/for_clustering.mp4')\n",
        "# Axes 1 and 2 of the decoded video are taken as the frame height and width.\n",
        "height, width = video.shape[1:3]\n",
        "# Preview every 5th frame at 10 fps.\n",
        "media.show_video(video[::5], fps=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7Vjhi4PdJ2W-"
      },
      "outputs": [],
      "source": [
        "# @title Run TAPIR to extract point tracks {form-width: \"25%\"}\n",
        "\n",
        "# Videos are keyed by episode id; this demo has a single clip stored under a\n",
        "# placeholder id.\n",
        "demo_videos = {\"dummy_id\": video}\n",
        "demo_episode_ids = list(demo_videos)\n",
        "\n",
        "# Track points through the video(s) with the downloaded checkpoint; the result\n",
        "# is consumed by the clustering cell below.\n",
        "track_dict = tapir_clustering.track_many_points(\n",
        "    demo_videos, demo_episode_ids, checkpoint_path,\n",
        "    point_batch_size=1024, points_per_frame=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kU2yqJVTPgg-"
      },
      "outputs": [],
      "source": [
        "# @title Run the clustering {form-width: \"25%\"}\n",
        "\n",
        "# Entries of track_dict, in the positional order they are handed to\n",
        "# compute_clusters.\n",
        "cluster_inputs = [\n",
        "    track_dict[key]\n",
        "    for key in (\n",
        "        'separation_tracks',\n",
        "        'separation_visibility',\n",
        "        'demo_episode_ids',\n",
        "        'video_shape',\n",
        "        'query_features',\n",
        "    )\n",
        "]\n",
        "\n",
        "clustered = tapir_clustering.compute_clusters(\n",
        "    *cluster_inputs, max_num_cats=12, final_num_cats=7)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FCNCAeLVQ0r2"
      },
      "outputs": [],
      "source": [
        "# @title Display the inferred clusters {form-width: \"25%\"}\n",
        "\n",
        "separation_visibility_trim = clustered['separation_visibility']\n",
        "separation_tracks_trim = clustered['separation_tracks']\n",
        "\n",
        "# Visualize the first demo episode (the only one in this notebook).\n",
        "episode_id = demo_episode_ids[0]\n",
        "episode_frames = demo_videos[episode_id].astype(np.uint8)\n",
        "episode_tracks = separation_tracks_trim[episode_id]\n",
        "# The third positional argument receives 1 - visibility; presumably an\n",
        "# occlusion indicator -- confirm against viz_utils.plot_tracks_v2.\n",
        "episode_not_visible = 1.0 - separation_visibility_trim[episode_id]\n",
        "\n",
        "pointtrack_video = viz_utils.plot_tracks_v2(\n",
        "    episode_frames,\n",
        "    episode_tracks,\n",
        "    episode_not_visible,\n",
        "    trackgroup=clustered['classes'],\n",
        ")\n",
        "media.show_video(pointtrack_video, fps=20)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
